{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jaosjY4rGRNH"
      },
      "source": [
        "# Installing NeMo from source\n",
        "\n",
        "\n",
        "You can run this notebook either locally (if you have all the dependencies and a GPU) or on Google Colab.\n",
        "\n",
        "Instructions for setting up Colab are as follows:\n",
        "1. Open a new Python 3 notebook.\n",
        "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n",
        "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n",
        "4. Run the cell below to set up dependencies.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "goQzOSflEq27"
      },
      "outputs": [],
      "source": [
        "import os \n",
        "!apt-get update && apt-get install -y libsndfile1 ffmpeg\n",
        "!git clone https://github.com/NVIDIA/NeMo --branch main\n",
        "os.chdir('NeMo')\n",
        "!./reinstall.sh\n",
        "os.chdir('..')\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GjQ_z_xQMDIb"
      },
      "source": [
        "# Overview\n",
        "\n",
        "This tutorial covers three tasks:\n",
        "\n",
        "1. Intent and Slot Classification using Assistant Dataset and a BERT model\n",
        "2. Intent Classification using Schema Guided Dialogue Dataset and a GPT2 model\n",
        "3. Answer Extender using MS Marco NLGen Dataset and a BART model\n",
        "\n",
        "Feel free to skip to the task that interests you most after installing NeMo from source."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AS-zwy8tEq2_"
      },
      "source": [
        "# 1. Intent and Slot Classification using Assistant Dataset\n",
        "\n",
        "## 1.1 Task Description\n",
        "\n",
        "**Joint Intent and Slot classification** is the task of classifying the Intent of a query and detecting all relevant Slots (Entities)\n",
        "for that Intent.\n",
        "For example, in the query `What is the weather in Santa Clara tomorrow morning?`, we would like to classify the query\n",
        "as a `weather` Intent, and detect `Santa Clara` as a `location` slot and `tomorrow morning` as a `date_time` slot.\n",
        "Intent and Slot names are usually task-specific and defined as labels in the training data.\n",
        "This is a fundamental step in task-driven Conversational Assistants.\n",
        "\n",
        "Our model is trained to perform both tasks jointly.\n",
        "\n",
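        "The joint task can be illustrated with a minimal, hypothetical sketch: one intent label for the whole query plus one slot label per word, with `O` marking words that belong to no slot. The helper `tag_words`, the span format, and the lowercased query are assumptions for illustration, not the exact NeMo data format.\n",
        "\n",
        "```python\n",
        "def tag_words(words, spans):\n",
        "    # Every word starts with the non-slot label 'O'.\n",
        "    labels = ['O'] * len(words)\n",
        "    # Each span is (start_word, end_word_exclusive, slot_name).\n",
        "    for start, end, slot in spans:\n",
        "        for i in range(start, end):\n",
        "            labels[i] = slot\n",
        "    return labels\n",
        "\n",
        "\n",
        "query = 'what is the weather in santa clara tomorrow morning'\n",
        "words = query.split()\n",
        "intent = 'weather'  # one label for the whole query\n",
        "slot_labels = tag_words(words, [(5, 7, 'location'), (7, 9, 'date_time')])\n",
        "for word, label in zip(words, slot_labels):\n",
        "    print(word, label)\n",
        "```\n",
        "\n",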
        "Note: A similar model is available in the [Joint Intent Slot Classification Colab](https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/nlp/Joint_Intent_and_Slot_Classification.ipynb). However, that model only supports BERT-style models, while the model in this tutorial also supports other types of models, such as GPT2."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FJk_UAyeEq3B"
      },
      "source": [
        "\n",
        "## 1.2 Download Assistant dataset and convert to NeMo format\n",
        "\n",
        "This virtual assistant interaction dataset can be downloaded from https://github.com/xliuhw/NLU-Evaluation-Data.\n",
        "It contains about 10K training and 1K test queries covering 64 Intents and 55 Slots.\n",
        "\n",
        "An example is:\n",
        "\n",
        "* utterance: what alarms have i set for tomorrow\n",
        "* intent: alarm_query\n",
        "* slots: date(tomorrow)\n",
        "\n",
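        "The raw dataset stores slots inline in the annotated utterance. Assuming the `[slot : value]` notation (an assumption here; check the dataset's CSV files), a short parser can turn an annotated utterance into per-word slot labels. `parse_annotation` is a hypothetical helper; the actual conversion is done by `import_datasets.py`.\n",
        "\n",
        "```python\n",
        "def parse_annotation(text):\n",
        "    # Hypothetical parser; assumes slots appear inline as [slot : value].\n",
        "    words, labels = [], []\n",
        "    rest = text\n",
        "    while '[' in rest:\n",
        "        before, rest = rest.split('[', 1)\n",
        "        inside, rest = rest.split(']', 1)\n",
        "        slot, value = inside.split(':', 1)\n",
        "        # Words outside brackets get the non-slot label 'O'.\n",
        "        for w in before.split():\n",
        "            words.append(w)\n",
        "            labels.append('O')\n",
        "        # Every word of the slot value gets the slot name.\n",
        "        for w in value.split():\n",
        "            words.append(w)\n",
        "            labels.append(slot.strip())\n",
        "    for w in rest.split():\n",
        "        words.append(w)\n",
        "        labels.append('O')\n",
        "    return words, labels\n",
        "\n",
        "\n",
        "words, labels = parse_annotation('what alarms have i set for [date : tomorrow]')\n",
        "print(list(zip(words, labels)))\n",
        "```\n",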
        "\n",
        "Note: While only the Assistant dataset is used here, `import_datasets.py` is also compatible with the ATIS and SNIPS datasets."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jjOVdGX2Eq3D"
      },
      "outputs": [],
      "source": [
        "# download and unzip the example dataset from github\n",
        "!wget https://github.com/xliuhw/NLU-Evaluation-Data/archive/master.zip\n",
        "!unzip master.zip\n",
        "# convert the dataset to the NeMo format\n",
        "!python NeMo/examples/nlp/intent_slot_classification/data/import_datasets.py --dataset_name=assistant --source_data_dir=./NLU-Evaluation-Data-master --target_data_dir=./assistant\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5n81deZsEq3G"
      },
      "source": [
        "## 1.3 Training and/or Testing the model\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eoYc_8jhEq3G"
      },
      "outputs": [],
      "source": [
        "# model.dataset.data_dir: folder to load data from\n",
        "# model.dataset.dialogues_example_dir: folder that stores predictions for each sample\n",
        "!(python NeMo/examples/nlp/dialogue/dialogue.py \\\n",
        "  do_training=True \\\n",
        "  model.dataset.data_dir='./assistant' \\\n",
        "  model.dataset.dialogues_example_dir='./assistant_bert_examples' \\\n",
        "  model.dataset.task='assistant' \\\n",
        "  model.language_model.pretrained_model_name='bert-base-uncased' \\\n",
        "  exp_manager.create_wandb_logger=False)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GaPmHjayEbg8"
      },
      "source": [
        "**Results after 3 epochs**\n",
        "\n",
        "Intent report: \n",
        "```\n",
        "    label                                                precision    recall       f1           support   \n",
        "    alarm_query (label_id: 0)                              100.00      94.44      97.14         18\n",
        "    alarm_remove (label_id: 1)                             100.00      90.91      95.24         11\n",
        "    alarm_set (label_id: 2)                                 94.12      94.12      94.12         17\n",
        "    audio_volume_down (label_id: 3)                         75.00      42.86      54.55          7\n",
        "    audio_volume_mute (label_id: 4)                        100.00      92.86      96.30         14\n",
        "    audio_volume_up (label_id: 5)                           72.22     100.00      83.87         13\n",
        "    calendar_query (label_id: 6)                            87.50      77.78      82.35         18\n",
        "    calendar_remove (label_id: 7)                           94.44     100.00      97.14         17\n",
        "    calendar_set (label_id: 8)                              94.44      94.44      94.44         18\n",
        "    cooking_recipe (label_id: 9)                            85.71      70.59      77.42         17\n",
        "    datetime_convert (label_id: 10)                         88.89     100.00      94.12          8\n",
        "    datetime_query (label_id: 11)                           89.47     100.00      94.44         17\n",
        "    email_addcontact (label_id: 12)                         80.00     100.00      88.89          8\n",
        "    email_query (label_id: 13)                             100.00      83.33      90.91         18\n",
        "    email_querycontact (label_id: 14)                       78.95      88.24      83.33         17\n",
        "    email_sendemail (label_id: 15)                          94.44      94.44      94.44         18\n",
        "    general_affirm (label_id: 16)                          100.00     100.00     100.00         17\n",
        "    general_commandstop (label_id: 17)                     100.00     100.00     100.00         18\n",
        "    general_confirm (label_id: 18)                         100.00     100.00     100.00         17\n",
        "    general_dontcare (label_id: 19)                        100.00     100.00     100.00         18\n",
        "    general_explain (label_id: 20)                         100.00     100.00     100.00         17\n",
        "    general_joke (label_id: 21)                             91.67     100.00      95.65         11\n",
        "    general_negate (label_id: 22)                          100.00     100.00     100.00         18\n",
        "    general_praise (label_id: 23)                          100.00     100.00     100.00         17\n",
        "    general_quirky (label_id: 24)                           60.00      50.00      54.55         18\n",
        "    general_repeat (label_id: 25)                          100.00     100.00     100.00         17\n",
        "    iot_cleaning (label_id: 26)                            100.00     100.00     100.00         15\n",
        "    iot_coffee (label_id: 27)                               85.71     100.00      92.31         18\n",
        "    iot_hue_lightchange (label_id: 28)                     100.00      94.12      96.97         17\n",
        "    iot_hue_lightdim (label_id: 29)                        100.00     100.00     100.00         12\n",
        "    iot_hue_lightoff (label_id: 30)                        100.00     100.00     100.00         17\n",
        "    iot_hue_lighton (label_id: 31)                         100.00      50.00      66.67          4\n",
        "    iot_hue_lightup (label_id: 32)                          84.62      91.67      88.00         12\n",
        "    iot_wemo_off (label_id: 33)                            100.00     100.00     100.00          9\n",
        "    iot_wemo_on (label_id: 34)                             100.00      85.71      92.31          7\n",
        "    lists_createoradd (label_id: 35)                        90.00     100.00      94.74         18\n",
        "    lists_query (label_id: 36)                             100.00      94.12      96.97         17\n",
        "    lists_remove (label_id: 37)                             88.89      88.89      88.89         18\n",
        "    music_likeness (label_id: 38)                          100.00      93.75      96.77         16\n",
        "    music_query (label_id: 39)                             100.00     100.00     100.00         17\n",
        "    music_settings (label_id: 40)                           77.78     100.00      87.50          7\n",
        "    news_query (label_id: 41)                               72.73      88.89      80.00         18\n",
        "    play_audiobook (label_id: 42)                          100.00     100.00     100.00         17\n",
        "    play_game (label_id: 43)                                93.75      83.33      88.24         18\n",
        "    play_music (label_id: 44)                               85.00     100.00      91.89         17\n",
        "    play_podcasts (label_id: 45)                           100.00      88.89      94.12         18\n",
        "    play_radio (label_id: 46)                               84.21      94.12      88.89         17\n",
        "    qa_currency (label_id: 47)                              85.00      94.44      89.47         18\n",
        "    qa_definition (label_id: 48)                            89.47     100.00      94.44         17\n",
        "    qa_factoid (label_id: 49)                               64.00      88.89      74.42         18\n",
        "    qa_maths (label_id: 50)                                 84.62      84.62      84.62         13\n",
        "    qa_stock (label_id: 51)                                 87.50      77.78      82.35         18\n",
        "    recommendation_events (label_id: 52)                    87.50      82.35      84.85         17\n",
        "    recommendation_locations (label_id: 53)                 83.33      83.33      83.33         18\n",
        "    recommendation_movies (label_id: 54)                   100.00      60.00      75.00         10\n",
        "    social_post (label_id: 55)                             100.00      94.12      96.97         17\n",
        "    social_query (label_id: 56)                            100.00      82.35      90.32         17\n",
        "    takeaway_order (label_id: 57)                           92.31      70.59      80.00         17\n",
        "    takeaway_query (label_id: 58)                           93.75      83.33      88.24         18\n",
        "    transport_query (label_id: 59)                          81.25      76.47      78.79         17\n",
        "    transport_taxi (label_id: 60)                          100.00     100.00     100.00         16\n",
        "    transport_ticket (label_id: 61)                         85.00      94.44      89.47         18\n",
        "    transport_traffic (label_id: 62)                        93.75      88.24      90.91         17\n",
        "    weather_query (label_id: 63)                            89.47     100.00      94.44         17\n",
        "    -------------------\n",
        "    micro avg                                               91.16      91.16      91.16        996\n",
        "    macro avg                                               91.66      90.44      90.48        996\n",
        "    weighted avg                                            91.72      91.16      91.04        996\n",
        "```\n",
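        "\n",
        "The micro-averaged precision, recall, and F1 in the intent report are identical (91.16). This is expected for single-label classification. Each wrong prediction counts as one false positive for the predicted intent and one false negative for the true intent, so micro precision, micro recall, and accuracy coincide. The macro average is instead the unweighted mean of the per-label scores. The sketch below computes both on made-up toy labels (not the results above). `micro_macro` is a hypothetical helper, not NeMo's metric code.\n",
        "\n",
        "```python\n",
        "from collections import Counter\n",
        "\n",
        "\n",
        "def micro_macro(y_true, y_pred):\n",
        "    # Collect every label seen in either the references or the predictions.\n",
        "    labels = sorted(set(y_true) | set(y_pred))\n",
        "    tp, fp, fn = Counter(), Counter(), Counter()\n",
        "    for t, p in zip(y_true, y_pred):\n",
        "        if t == p:\n",
        "            tp[t] += 1\n",
        "        else:\n",
        "            # One error is one FP for the predicted label and one FN for the true label.\n",
        "            fp[p] += 1\n",
        "            fn[t] += 1\n",
        "    total_tp = sum(tp.values())\n",
        "    micro_p = total_tp / (total_tp + sum(fp.values()))\n",
        "    micro_r = total_tp / (total_tp + sum(fn.values()))\n",
        "    # Macro F1: unweighted mean of per-label F1 (0 when undefined).\n",
        "    f1s = []\n",
        "    for lab in labels:\n",
        "        p = tp[lab] / (tp[lab] + fp[lab]) if tp[lab] + fp[lab] else 0.0\n",
        "        r = tp[lab] / (tp[lab] + fn[lab]) if tp[lab] + fn[lab] else 0.0\n",
        "        f1s.append(2 * p * r / (p + r) if p + r else 0.0)\n",
        "    return micro_p, micro_r, sum(f1s) / len(labels)\n",
        "\n",
        "\n",
        "y_true = ['alarm_set', 'alarm_set', 'alarm_query', 'weather_query']\n",
        "y_pred = ['alarm_set', 'alarm_query', 'alarm_query', 'weather_query']\n",
        "print(micro_macro(y_true, y_pred))\n",
        "```\n",
        "\n",
        "Here micro precision and micro recall are both 0.75 (3 of 4 predictions correct), while macro F1 is 7/9 because two of the three labels score an F1 of 2/3.\n",
        "\n",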
        "Slot report: \n",
        "```\n",
        "    label                                                precision    recall       f1           support   \n",
        "    alarm_type (label_id: 0)                                 0.00       0.00       0.00          2\n",
        "    app_name (label_id: 1)                                   0.00       0.00       0.00          1\n",
        "    artist_name (label_id: 2)                               17.39      80.00      28.57          5\n",
        "    audiobook_author (label_id: 3)                           0.00       0.00       0.00          0\n",
        "    audiobook_name (label_id: 4)                            64.52      74.07      68.97         27\n",
        "    business_name (label_id: 5)                             81.48      84.62      83.02         52\n",
        "    business_type (label_id: 6)                             80.00      80.00      80.00         20\n",
        "    change_amount (label_id: 7)                             57.14      66.67      61.54          6\n",
        "    coffee_type (label_id: 8)                              100.00      33.33      50.00          3\n",
        "    color_type (label_id: 9)                                75.00      92.31      82.76         13\n",
        "    cooking_type (label_id: 10)                              0.00       0.00       0.00          1\n",
        "    currency_name (label_id: 11)                           100.00      96.43      98.18         28\n",
        "    date (label_id: 12)                                     87.88      87.22      87.55        133\n",
        "    definition_word (label_id: 13)                          85.00      85.00      85.00         20\n",
        "    device_type (label_id: 14)                              84.75      76.92      80.65         65\n",
        "    drink_type (label_id: 15)                                0.00       0.00       0.00          0\n",
        "    email_address (label_id: 16)                            64.29     100.00      78.26          9\n",
        "    email_folder (label_id: 17)                            100.00      50.00      66.67          2\n",
        "    event_name (label_id: 18)                               80.00      75.00      77.42         64\n",
        "    food_type (label_id: 19)                                84.38      77.14      80.60         35\n",
        "    game_name (label_id: 20)                                93.55      78.38      85.29         37\n",
        "    game_type (label_id: 21)                                 0.00       0.00       0.00          0\n",
        "    general_frequency (label_id: 22)                         0.00       0.00       0.00          9\n",
        "    house_place (label_id: 23)                              80.95      91.89      86.08         37\n",
        "    ingredient (label_id: 24)                                0.00       0.00       0.00          1\n",
        "    joke_type (label_id: 25)                               100.00     100.00     100.00          5\n",
        "    list_name (label_id: 26)                                89.29      69.44      78.12         36\n",
        "    meal_type (label_id: 27)                                 0.00       0.00       0.00          3\n",
        "    media_type (label_id: 28)                               78.95      83.33      81.08         36\n",
        "    movie_name (label_id: 29)                                0.00       0.00       0.00          1\n",
        "    movie_type (label_id: 30)                                0.00       0.00       0.00          0\n",
        "    music_album (label_id: 31)                               0.00       0.00       0.00          0\n",
        "    music_descriptor (label_id: 32)                          0.00       0.00       0.00          2\n",
        "    music_genre (label_id: 33)                              81.82      90.00      85.71         10\n",
        "    news_topic (label_id: 34)                               80.00      30.77      44.44         13\n",
        "    order_type (label_id: 35)                              100.00      42.11      59.26         19\n",
        "    person (label_id: 36)                                   70.79     100.00      82.89         63\n",
        "    personal_info (label_id: 37)                            76.19      94.12      84.21         17\n",
        "    place_name (label_id: 38)                               82.86      84.47      83.65        103\n",
        "    player_setting (label_id: 39)                           75.00      42.86      54.55          7\n",
        "    playlist_name (label_id: 40)                             0.00       0.00       0.00          3\n",
        "    podcast_descriptor (label_id: 41)                       92.31      54.55      68.57         22\n",
        "    podcast_name (label_id: 42)                             66.67      16.67      26.67         12\n",
        "    radio_name (label_id: 43)                               94.87      94.87      94.87         39\n",
        "    relation (label_id: 44)                                 90.91      90.91      90.91         11\n",
        "    song_name (label_id: 45)                               100.00       6.67      12.50         15\n",
        "    time (label_id: 46)                                     77.57      84.69      80.98         98\n",
        "    time_zone (label_id: 47)                                44.44     100.00      61.54          4\n",
        "    timeofday (label_id: 48)                                86.96      80.00      83.33         25\n",
        "    transport_agency (label_id: 49)                         80.00      57.14      66.67          7\n",
        "    transport_descriptor (label_id: 50)                      0.00       0.00       0.00          5\n",
        "    transport_name (label_id: 51)                            0.00       0.00       0.00          0\n",
        "    transport_type (label_id: 52)                           88.89     100.00      94.12         40\n",
        "    weather_descriptor (label_id: 53)                       87.50      87.50      87.50          8\n",
        "    O (label_id: 54)                                        97.07      97.52      97.30       5408\n",
        "    -------------------\n",
        "    micro avg                                               94.24      94.24      94.24       6582\n",
        "    macro avg                                               64.87      59.93      59.17       6582\n",
        "    weighted avg                                            94.23      94.24      93.95       6582\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 1.4 (Optional) Training and/or Testing a GPT2 model on the Assistant dataset"
      ],
      "metadata": {
        "id": "-44x5PqyrOeQ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# model.dataset.data_dir: folder to load data from\n",
        "# model.dataset.dialogues_example_dir: folder that stores predictions for each sample\n",
        "# model.tokenizer.special_tokens=\"{pad_token:'<|endoftext|>'}\": GPT2 does not define a pad token, so its EOS token is used as the pad token\n",
        "# model.dataset.target_template=with_slots: performs slot filling together with intent classification\n",
        "!(python NeMo/examples/nlp/dialogue/dialogue.py \\\n",
        "  do_training=True \\\n",
        "  model.dataset.data_dir='./assistant' \\\n",
        "  model.dataset.dialogues_example_dir='./assistant_gpt2_examples' \\\n",
        "  model.dataset.task='assistant' \\\n",
        "  model.language_model.pretrained_model_name='gpt2' \\\n",
        "  trainer.max_epochs=1 \\\n",
        "  model.tokenizer.special_tokens=\"{pad_token:'<|endoftext|>'}\" \\\n",
        "  model.dataset.target_template=with_slots \\\n",
        "  model.dataset.eval_mode=generation \\\n",
        "  exp_manager.create_wandb_logger=False)"
      ],
      "metadata": {
        "id": "QyqQbpR4rNHT"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "**After 1 epoch:**\n",
        "\n",
        "More epochs would likely improve these results.\n",
        "\n",
        "Intent report:\n",
        "\n",
        "```\n",
        "    label                                                precision    recall       f1           support   \n",
        "    transport query (label_id: 0)                           72.73      84.21      78.05         19\n",
        "    weather query (label_id: 1)                             94.74      94.74      94.74         19\n",
        "    play game (label_id: 2)                                 92.86      68.42      78.79         19\n",
        "    qa currency (label_id: 3)                              100.00     100.00     100.00         19\n",
        "    qa maths (label_id: 4)                                 100.00     100.00     100.00         14\n",
        "    iot wemo off (label_id: 5)                              75.00     100.00      85.71          9\n",
        "    datetime convert (label_id: 6)                          46.67      87.50      60.87          8\n",
        "    email addcontact (label_id: 7)                          70.00      87.50      77.78          8\n",
        "    music likeness (label_id: 8)                            57.89      61.11      59.46         18\n",
        "    music query (label_id: 9)                               78.57      57.89      66.67         19\n",
        "    general negate (label_id: 10)                           95.00     100.00      97.44         19\n",
        "    email sendemail (label_id: 11)                          92.86      68.42      78.79         19\n",
        "    general affirm (label_id: 12)                           95.00     100.00      97.44         19\n",
        "    play audiobook (label_id: 13)                           57.69      78.95      66.67         19\n",
        "    general praise (label_id: 14)                          100.00      94.74      97.30         19\n",
        "    alarm set (label_id: 15)                                85.71      94.74      90.00         19\n",
        "    general explain (label_id: 16)                         100.00      89.47      94.44         19\n",
        "    iot wemo on (label_id: 17)                              83.33      71.43      76.92          7\n",
        "    cooking recipe (label_id: 18)                           90.00      94.74      92.31         19\n",
        "    music settings (label_id: 19)                           60.00      42.86      50.00          7\n",
        "    social post (label_id: 20)                              84.21      84.21      84.21         19\n",
        "    recommendation events (label_id: 21)                    72.73      84.21      78.05         19\n",
        "    audio volume up (label_id: 22)                          76.47     100.00      86.67         13\n",
        "    lists remove (label_id: 23)                             73.08     100.00      84.44         19\n",
        "    transport ticket (label_id: 24)                         94.74      94.74      94.74         19\n",
        "    general joke (label_id: 25)                            100.00     100.00     100.00         12\n",
        "    play podcasts (label_id: 26)                            94.12      84.21      88.89         19\n",
        "    iot hue lightchange (label_id: 27)                      85.71      63.16      72.73         19\n",
        "    audio volume mute (label_id: 28)                        84.62      73.33      78.57         15\n",
        "    general dontcare (label_id: 29)                         95.00     100.00      97.44         19\n",
        "    qa definition (label_id: 30)                            77.27      89.47      82.93         19\n",
        "    email querycontact (label_id: 31)                       58.33      73.68      65.12         19\n",
        "    general commandstop (label_id: 32)                     100.00     100.00     100.00         19\n",
        "    calendar remove (label_id: 33)                          94.44      89.47      91.89         19\n",
        "    news query (label_id: 34)                              100.00      57.89      73.33         19\n",
        "    calendar query (label_id: 35)                           63.16      63.16      63.16         19\n",
        "    social query (label_id: 36)                             88.24      83.33      85.71         18\n",
        "    transport traffic (label_id: 37)                        90.48     100.00      95.00         19\n",
        "    transport taxi (label_id: 38)                          100.00      94.44      97.14         18\n",
        "    alarm query (label_id: 39)                             100.00      94.74      97.30         19\n",
        "    iot hue lightoff (label_id: 40)                         88.89      84.21      86.49         19\n",
        "    takeaway order (label_id: 41)                           81.25      68.42      74.29         19\n",
        "    iot coffee (label_id: 42)                              100.00      94.74      97.30         19\n",
        "    recommendation movies (label_id: 43)                    75.00      90.00      81.82         10\n",
        "    iot hue lightup (label_id: 44)                          78.57      78.57      78.57         14\n",
        "    email query (label_id: 45)                              85.71      94.74      90.00         19\n",
        "    lists createoradd (label_id: 46)                        82.35      73.68      77.78         19\n",
        "    play radio (label_id: 47)                               84.21      84.21      84.21         19\n",
        "    audio volume down (label_id: 48)                       100.00      87.50      93.33          8\n",
        "    general quirky (label_id: 49)                           30.00      15.79      20.69         19\n",
        "    play music (label_id: 50)                               71.43      52.63      60.61         19\n",
        "    qa stock (label_id: 51)                                 90.48     100.00      95.00         19\n",
        "    iot cleaning (label_id: 52)                             93.33      87.50      90.32         16\n",
        "    iot hue lightdim (label_id: 53)                        100.00     100.00     100.00         12\n",
        "    recommendation locations (label_id: 54)                100.00      89.47      94.44         19\n",
        "    general repeat (label_id: 55)                          100.00     100.00     100.00         19\n",
        "    takeaway query (label_id: 56)                           77.27      89.47      82.93         19\n",
        "    alarm remove (label_id: 57)                            100.00     100.00     100.00         11\n",
        "    datetime query (label_id: 58)                           75.00      63.16      68.57         19\n",
        "    iot hue lighton (label_id: 59)                          60.00     100.00      75.00          3\n",
        "    qa factoid (label_id: 60)                               50.00      57.89      53.66         19\n",
        "    calendar set (label_id: 61)                             75.00      78.95      76.92         19\n",
        "    general confirm (label_id: 62)                         100.00     100.00     100.00         19\n",
        "    lists query (label_id: 63)                              66.67      73.68      70.00         19\n",
        "    label_id: 64                                             0.00       0.00       0.00          0\n",
        "    -------------------\n",
        "    micro avg                                               83.55      83.55      83.55       1076\n",
        "    macro avg                                               83.53      83.93      83.01       1076\n",
        "    weighted avg                                            84.26      83.55      83.30       1076\n",
        "    \n",
        "```\n",
        "\n",
        "```\n",
        "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
        "       Test metric             DataLoader 0\n",
        "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
        "        intent_f1            83.55018615722656\n",
        "    intent_precision         83.55018615722656\n",
        "      intent_recall          83.55018615722656\n",
        "         slot_f1             73.99985919756773\n",
        "slot_joint_goal_accuracy     65.89219330855019\n",
        "     slot_precision          73.85223048327137\n",
        "       slot_recall           74.14807930607186\n",
        "  test_intent_accuracy       83.55018587360595\n",
        "     test_loss_epoch       0.019178826361894608\n",
        "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
        "```"
      ],
      "metadata": {
        "id": "FbQ-6TVM1yQg"
      }
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Gd42arYoEq3J"
      },
      "source": [
        "# 2. Schema Guided Dialogue (SGD)\n",
        "\n",
        "## 2.1 Task Description\n",
        "---\n",
        "\n",
        "SGD is a multi-domain intent classification dataset from Google with close to 100k examples.\n",
        "\n",
        "An example is:\n",
        "\n",
        "* utterance: I will be eating there at 11:30 am so make it for then.\n",
        "* intent: ReserveRestaurant\n",
        "* slots: {\"time\": \"11:30 am\"}\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "neH8rXwjEq3J"
      },
      "source": [
        "## 2.2 Download the dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IgD8eavfJ5pi"
      },
      "outputs": [],
      "source": [
        "!git clone https://github.com/google-research-datasets/dstc8-schema-guided-dialogue.git"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7G7uPrUpEq3J"
      },
      "source": [
        "## 2.3 Training and/or Testing the model\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gqo-rwQlEq3K"
      },
      "outputs": [],
      "source": [
        "# model.dataset.data_dir: folder to load data from\n",
        "# model.dataset.dialogues_example_dir: folder that stores predictions for each sample\n",
        "# model.tokenizer.special_tokens=\"{pad_token:'<|endoftext|>'}\": GPT-2 does not define a pad token, so its EOS token (<|endoftext|>) is reused as the pad token\n",
        "\n",
        "!(python NeMo/examples/nlp/dialogue/dialogue.py \\\n",
        "  do_training=True \\\n",
        "  model.dataset.data_dir='./dstc8-schema-guided-dialogue' \\\n",
        "  model.dataset.dialogues_example_dir='./sgd_gpt2_predictions' \\\n",
        "  model.dataset.task='sgd' \\\n",
        "  model.language_model.pretrained_model_name='gpt2' \\\n",
        "  trainer.max_epochs=1 \\\n",
        "  model.tokenizer.special_tokens=\"{pad_token:'<|endoftext|>'}\" \\\n",
        "  exp_manager.create_wandb_logger=False)\n"
      ]
    },
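    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "GPT-2's tokenizer defines no padding token, which is why the command above reuses `<|endoftext|>` as `pad_token`. This is generally safe because padded positions are masked out. The next cell is a framework-free sketch of that idea, not NeMo's collation code: the helper `pad_batch` is hypothetical, the token ids 40, 481 and 307 are arbitrary, 50256 is the id of `<|endoftext|>` in the GPT-2 vocabulary, and -100 follows the default `ignore_index` of PyTorch's `CrossEntropyLoss`. Masking by position rather than by token id matters here: since the pad and EOS ids coincide, masking every 50256 would also hide genuine end-of-sequence tokens."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Minimal sketch (hypothetical helper, not NeMo's collate function)\n",
        "EOS_ID = 50256       # GPT-2's <|endoftext|> token id, reused as the pad id\n",
        "IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss\n",
        "\n",
        "\n",
        "def pad_batch(sequences, pad_id=EOS_ID):\n",
        "    max_len = max(len(seq) for seq in sequences)\n",
        "    input_ids, attention_mask, labels = [], [], []\n",
        "    for seq in sequences:\n",
        "        n_pad = max_len - len(seq)\n",
        "        input_ids.append(seq + [pad_id] * n_pad)\n",
        "        # 1 = real token the model should attend to, 0 = padding\n",
        "        attention_mask.append([1] * len(seq) + [0] * n_pad)\n",
        "        # mask padded positions (by position, not by id) out of the loss\n",
        "        labels.append(seq + [IGNORE_INDEX] * n_pad)\n",
        "    return input_ids, attention_mask, labels\n",
        "\n",
        "\n",
        "ids, mask, labels = pad_batch([[40, 481, 307], [40, 481]])\n",
        "print(ids)     # [[40, 481, 307], [40, 481, 50256]]\n",
        "print(mask)    # [[1, 1, 1], [1, 1, 0]]\n",
        "print(labels)  # [[40, 481, 307], [40, 481, -100]]"
      ]
    },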
    {
      "cell_type": "code",
      "source": [
        "!ls sgd_gpt2_predictions"
      ],
      "metadata": {
        "id": "kGDlV5HvI2PQ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "p8g0f5KDTu9K"
      },
      "source": [
        "**After 1 epoch:**\n",
        "\n",
        "More epochs are needed to reach convergence.\n",
        "\n",
        "\n",
        "```\n",
        "    label                                                precision    recall       f1           support   \n",
        "    check balance (label_id: 0)                              0.00       0.00       0.00          0\n",
        "    find trains (label_id: 1)                               80.20      91.95      85.68        348\n",
        "    make payment (label_id: 2)                              83.12      28.07      41.97        228\n",
        "    book appointment (label_id: 3)                          86.93      87.15      87.04        397\n",
        "    get cars available (label_id: 4)                        96.88      90.51      93.58        274\n",
        "    get event dates (label_id: 5)                            0.00       0.00       0.00          0\n",
        "    buy bus ticket (label_id: 6)                            78.61      91.33      84.49        173\n",
        "    add event (label_id: 7)                                  0.00       0.00       0.00          0\n",
        "    get alarms (label_id: 8)                                58.33      77.78      66.67         45\n",
        "    reserve car (label_id: 9)                               83.75      72.43      77.68        185\n",
        "    get events (label_id: 10)                                0.00       0.00       0.00          0\n",
        "    reserve roundtrip flights (label_id: 11)                 0.00       0.00       0.00          0\n",
        "    lookup music (label_id: 12)                             89.83      86.89      88.33         61\n",
        "    book house (label_id: 13)                               91.13      92.50      91.81        200\n",
        "    search oneway flight (label_id: 14)                     74.77      47.70      58.25        174\n",
        "    buy event tickets (label_id: 15)                        72.19      95.31      82.15        128\n",
        "    find apartment (label_id: 16)                            0.00       0.00       0.00          0\n",
        "    schedule visit (label_id: 17)                           77.27      66.06      71.23        386\n",
        "    play media (label_id: 18)                               92.94      86.81      89.77         91\n",
        "    get ride (label_id: 19)                                 99.41      98.82      99.12        170\n",
        "    reserve oneway flight (label_id: 20)                     0.00       0.00       0.00          0\n",
        "    find bus (label_id: 21)                                 96.64      87.53      91.86        361\n",
        "    find restaurants (label_id: 22)                         77.14      91.22      83.59        148\n",
        "    get times for movie (label_id: 23)                       0.00       0.00       0.00          0\n",
        "    transfer money (label_id: 24)                            0.00       0.00       0.00          0\n",
        "    request payment (label_id: 25)                          46.71      63.39      53.79        112\n",
        "    play movie (label_id: 26)                              100.00      65.11      78.87        321\n",
        "    search house (label_id: 27)                             97.91      91.83      94.77        306\n",
        "    search roundtrip flights (label_id: 28)                 67.49      82.41      74.21        199\n",
        "    find provider (label_id: 29)                            95.11      90.53      92.77        602\n",
        "    find attractions (label_id: 30)                        100.00      89.01      94.19         91\n",
        "    reserve hotel (label_id: 31)                            56.75      97.04      71.62        169\n",
        "    lookup song (label_id: 32)                               0.00       0.00       0.00          0\n",
        "    add alarm (label_id: 33)                                95.68      60.18      73.89        221\n",
        "    find home by area (label_id: 34)                        48.95      59.79      53.83        194\n",
        "    get available time (label_id: 35)                        0.00       0.00       0.00          0\n",
        "    buy movie tickets (label_id: 36)                       100.00      29.39      45.42        473\n",
        "    reserve restaurant (label_id: 37)                       95.71      84.80      89.92        342\n",
        "    find movies (label_id: 38)                              62.40      97.61      76.14        335\n",
        "    get weather (label_id: 39)                             100.00      87.69      93.44        195\n",
        "    search hotel (label_id: 40)                             99.35      52.60      68.78        289\n",
        "    find events (label_id: 41)                              99.57      82.56      90.27        281\n",
        "    play song (label_id: 42)                                 0.00       0.00       0.00          0\n",
        "    rent movie (label_id: 43)                                0.00       0.00       0.00          0\n",
        "    get train tickets (label_id: 44)                        45.83       5.56       9.91        198\n",
        "    none (label_id: 45)                                     55.77      98.90      71.32        728\n",
        "    label_id: 46                                             0.00       0.00       0.00          0\n",
        "    -------------------\n",
        "    micro avg                                               77.23      77.23      77.23       8425\n",
        "    macro avg                                               82.01      76.68      76.56       8425\n",
        "    weighted avg                                            83.23      77.23      76.86       8425\n",
        "\n",
        "```"
      ]
    },
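    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Each F1 value in the report above is the harmonic mean of that row's precision and recall, F1 = 2PR / (P + R). The next cell recomputes a few rows as a sanity check. Because the report prints P and R rounded to two decimals, the recomputed F1 can differ from the printed one in the second decimal. Rows with zero support (for example `check balance`) show P = R = 0, and the sketch returns 0 in that case instead of dividing by zero."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Recompute F1 = 2PR / (P + R) for a few rows of the SGD report above (values in %)\n",
        "def f1_score(precision, recall):\n",
        "    if precision + recall == 0:\n",
        "        return 0.0  # zero-support labels are reported with P = R = 0\n",
        "    return 2 * precision * recall / (precision + recall)\n",
        "\n",
        "\n",
        "rows = {\n",
        "    # label: (precision, recall, reported F1), copied from the report\n",
        "    'find trains': (80.20, 91.95, 85.68),\n",
        "    'get ride': (99.41, 98.82, 99.12),\n",
        "    'get train tickets': (45.83, 5.56, 9.91),\n",
        "    'check balance': (0.00, 0.00, 0.00),\n",
        "}\n",
        "\n",
        "for label, (p, r, reported) in rows.items():\n",
        "    recomputed = f1_score(p, r)\n",
        "    # P and R are printed rounded to 2 decimals, so small mismatches are expected\n",
        "    print(f'{label:18s} recomputed={recomputed:6.2f}  reported={reported:6.2f}')"
      ]
    },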
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jUJb-9VLLBXo"
      },
      "source": [
        "# 3. MS Marco\n",
        "\n",
        "## 3.1 Task Description\n",
        "\n",
        "MS Marco NLGen is a dataset from Microsoft in which each sample pairs a question and an extracted answer with a fluent, well-formed answer; the task is to generate the fluent answer.\n",
        "\n",
        "An example is:\n",
        "\n",
        "*   question: what county is nine mile in\n",
        "*   extracted_answer: Onondaga\n",
        "*   fluent_answer: Nine Mile is in Onondaga county.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VtXEKG_UQU9u"
      },
      "source": [
        "## 3.2 Download, unzip, and filter the dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b9avsZ1CEq3K"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "!mkdir -p ms_marco\n",
        "os.chdir('ms_marco')\n",
        "!wget https://msmarco.blob.core.windows.net/msmarco/train_v2.1.json.gz\n",
        "!wget https://msmarco.blob.core.windows.net/msmarco/dev_v2.1.json.gz\n",
        "\n",
        "!gunzip train_v2.1.json.gz\n",
        "!gunzip dev_v2.1.json.gz\n",
        "\n",
        "# Keep only samples that include a well-formed (fluent) answer\n",
        "!python ../NeMo/examples/nlp/dialogue/remove_ms_marco_samples_without_wellFormedAnswers.py --filename train_v2.1.json\n",
        "!python ../NeMo/examples/nlp/dialogue/remove_ms_marco_samples_without_wellFormedAnswers.py --filename dev_v2.1.json\n",
        "\n",
        "os.chdir('..')"
      ]
    },
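    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As its name suggests, the filtering script above drops samples that have no well-formed (fluent) answer, since those answers are the targets the model learns to generate. As a rough illustration of that kind of filter, the next cell works on a tiny in-memory stand-in for the MS Marco v2.1 JSON. It assumes the column-oriented layout of the v2.1 files (top-level keys such as `query`, `answers` and `wellFormedAnswers`, each mapping a sample index to a value) and treats both an empty list and the string `'[]'` as a missing answer. The helpers `has_well_formed_answer` and `keep_well_formed` and the toy data are hypothetical; this is not the contents of `remove_ms_marco_samples_without_wellFormedAnswers.py`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hypothetical sketch; assumes a column-oriented dict {column: {sample_index: value}}\n",
        "def has_well_formed_answer(value):\n",
        "    # Missing answers are assumed to appear as [] or as the string '[]'\n",
        "    return isinstance(value, list) and len(value) > 0\n",
        "\n",
        "\n",
        "def keep_well_formed(data):\n",
        "    keep = [idx for idx, value in data['wellFormedAnswers'].items()\n",
        "            if has_well_formed_answer(value)]\n",
        "    # Rebuild every column with only the kept sample indices\n",
        "    return {column: {idx: values[idx] for idx in keep}\n",
        "            for column, values in data.items()}\n",
        "\n",
        "\n",
        "toy_data = {\n",
        "    'query': {'0': 'what county is nine mile in', '1': 'what is a dog'},\n",
        "    'answers': {'0': ['Onondaga'], '1': ['A domesticated mammal.']},\n",
        "    'wellFormedAnswers': {'0': ['Nine Mile is in Onondaga county.'], '1': '[]'},\n",
        "}\n",
        "\n",
        "filtered = keep_well_formed(toy_data)\n",
        "print(filtered['query'])  # {'0': 'what county is nine mile in'}"
      ]
    },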
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "h7UZ9R8gQTFo"
      },
      "source": [
        "## 3.3 Training and/or Testing the model\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fwGQCwbvRf2m"
      },
      "outputs": [],
      "source": [
        "# model.dataset.data_dir: folder to load data from\n",
        "# model.dataset.dialogues_example_dir: folder that stores predictions for each sample\n",
        "\n",
        "!(python NeMo/examples/nlp/dialogue/dialogue.py \\\n",
        "  do_training=True \\\n",
        "  model.dataset.dialogues_example_dir='./marco_bart_predictions' \\\n",
        "  model.dataset.data_dir='./ms_marco' \\\n",
        "  model.save_model=True \\\n",
        "  model.dataset.task='ms_marco' \\\n",
        "  model.language_model.pretrained_model_name='facebook/bart-base' \\\n",
        "  trainer.max_epochs=1 \\\n",
        "  model.dataset.debug_mode=False \\\n",
        "  exp_manager.create_wandb_logger=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UL7ekAOZ2abi"
      },
      "source": [
        "**After 1 epoch:**\n",
        "\n",
        "Train for more epochs to reach optimal performance.\n",
        "\n",
        "```\n",
        "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
        "       Test metric             DataLoader 0\n",
        "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
        "          bleu               65.46179962158203\n",
        "           f1                78.24439835896995\n",
        "        precision            81.92473076099847\n",
        "         recall              76.72508929408436\n",
        "      test_accuracy         25.563487607283225\n",
        "        test_loss           0.4419259166606655\n",
        "     test_loss_epoch        0.4420809745788574\n",
        "        test_ppl            1.5557004846779854\n",
        "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
        "```"
      ]
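    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The reported `test_ppl` is consistent with perplexity defined as the exponential of the mean cross-entropy loss: exp(0.4419) ≈ 1.5557. The next cell checks this using the logged `test_loss` and `test_ppl` values. It is only a consistency check on the printed numbers, not a statement about how NeMo computes the metric internally."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "# Values copied from the test metrics printed above\n",
        "test_loss = 0.4419259166606655\n",
        "test_ppl = 1.5557004846779854\n",
        "\n",
        "# Perplexity = exp(mean token-level cross-entropy, in nats)\n",
        "recomputed_ppl = math.exp(test_loss)\n",
        "print(f'exp(test_loss) = {recomputed_ppl:.6f}, reported test_ppl = {test_ppl:.6f}')"
      ]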
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "Dialogue.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}